/*******************************************************************************
 * Package: com.song.sql
 * Type:    MyAvg
 * Date:    2024-12-15 19:50
 *
 * Copyright (c) 2024 LTD All Rights Reserved.
 *
 * You may not use this file except in compliance with the License.
 *******************************************************************************/
package com.song.sql;


import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.expressions.Aggregator;

/**
 * Typed Spark {@link Aggregator} that computes the arithmetic mean of {@code Long} inputs.
 *
 * <p>The intermediate buffer ({@link BufferObj}) carries a running sum and a running count
 * of the values seen so far. It is serialized with Kryo (see {@link #bufferEncoder()}).
 * The final result is {@code sum / count} as a {@code Double}.
 *
 * <p>Semantics follow SQL {@code AVG}: {@code null} input values are ignored. When no
 * non-null value was aggregated, the result is {@code null} rather than {@code NaN}.
 *
 * @author Songxianyang
 * @date 2024-12-15 19:50
 */
public class MyAvg extends Aggregator<Long,BufferObj,Double> {
    private static final long serialVersionUID = 5265883557744100945L;

    /**
     * Returns the initial (empty) buffer: sum = 0, count = 0.
     *
     * @return a fresh buffer with zero sum and zero count
     */
    @Override
    public BufferObj zero() {
        return new BufferObj(0L, 0L);
    }

    /**
     * Folds one input value into the buffer.
     *
     * <p>The buffer is updated in place and returned, which Spark's {@code Aggregator}
     * contract permits.
     *
     * @param b the current partial buffer
     * @param a the next input value; {@code null} is skipped, matching SQL {@code AVG}
     * @return the updated buffer
     */
    @Override
    public BufferObj reduce(BufferObj b, Long a) {
        if (a == null) {
            // Unboxing a null Long would throw a NullPointerException; ignore it like AVG does.
            return b;
        }
        b.setSum(b.getSum() + a);
        b.setCount(b.getCount() + 1);
        return b;
    }

    /**
     * Combines two partial buffers (e.g. from different partitions) by adding
     * their sums and counts.
     *
     * @param b1 the buffer that receives the merged totals (mutated and returned)
     * @param b2 the buffer whose totals are added into {@code b1}
     * @return {@code b1} holding the combined sum and count
     */
    @Override
    public BufferObj merge(BufferObj b1, BufferObj b2) {
        b1.setCount(b1.getCount() + b2.getCount());
        b1.setSum(b1.getSum() + b2.getSum());
        return b1;
    }

    /**
     * Produces the final average from the fully merged buffer.
     *
     * @param reduction the merged buffer
     * @return {@code sum / count} as a double, or {@code null} when no values were
     *     aggregated (avoids returning {@code NaN} from {@code 0.0 / 0})
     */
    @Override
    public Double finish(BufferObj reduction) {
        if (reduction.getCount() == 0) {
            return null;
        }
        return reduction.getSum().doubleValue() / reduction.getCount();
    }

    /**
     * Encoder for the intermediate buffer; {@link BufferObj} is stored as a Kryo-serialized blob.
     *
     * @return a Kryo encoder for {@link BufferObj}
     */
    @Override
    public Encoder<BufferObj> bufferEncoder() {
        return Encoders.kryo(BufferObj.class);
    }

    /**
     * Encoder for the output value (a nullable boxed {@code Double}).
     *
     * @return the built-in {@code Double} encoder
     */
    @Override
    public Encoder<Double> outputEncoder() {
        return Encoders.DOUBLE();
    }
}
